from __future__ import absolute_import, division, print_function, \
    unicode_literals
import sys
sys.path.append("..")  # must precede the local `nfs` import below
import argparse
import os
from uuid import uuid1

import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms

from nfs.aiflow.report import *

class DealDataset(Dataset):
    """In-memory dataset backed by two ``.npy`` files.

    Loads the image array and the label array once at construction via
    ``load_data`` and serves ``(image, int_label)`` pairs, optionally
    passing each image through ``transform``.
    """

    def __init__(self, data_name, label_name, transform=None):
        images, labels = load_data(data_name, label_name)
        self.train_set = images      # image array, indexable by sample
        self.train_labels = labels   # label array, parallel to train_set
        self.transform = transform   # optional per-image callable

    def __getitem__(self, index):
        sample = self.train_set[index]
        target = int(self.train_labels[index])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.train_set)


def load_data(data_name, label_name):
    """Load image and label arrays from their ``.npy`` files.

    data_name:  path to the ``.npy`` file holding the images
    label_name: path to the ``.npy`` file holding the labels

    Returns a ``(images, labels)`` tuple of numpy arrays.
    """
    # Labels are read first, matching the original load order.
    labels = np.load(label_name)
    images = np.load(data_name)
    return (images, labels)

class CNN(nn.Module):
    """Small convolutional classifier producing 10 class logits.

    Two conv + batch-norm + ReLU stages, each followed by 2x2
    max-pooling, then a three-layer fully-connected head. The FC input
    size (50 * 5 * 5) implies 28x28 single-channel inputs.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # NOTE: module creation order is kept identical to the original
        # so seeded parameter initialization is reproducible.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 25, kernel_size=3),
            nn.BatchNorm2d(25),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(25, 50, kernel_size=3),
            nn.BatchNorm2d(50),
            nn.ReLU(inplace=True),
        )
        self.layer4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(50 * 5 * 5, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        # Run the four conv/pool stages in order, then flatten for the head.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = x.view(x.size(0), -1)  # (batch, 50 * 5 * 5)
        return self.fc(x)

def train_model(data_file, output_dir, batch_size, learning_rate, num_epoches):
    """Train the CNN on .npy image/label data and save the weights.

    :param data_file: comma-separated string of four .npy paths in the
        order: train images, test images, train labels, test labels
        (kubeflow passes the txt-file *content* directly, not its path)
    :param output_dir: directory where the trained ``<uuid>.pt`` weights
        and the ``model.txt`` pointer file are written
    :param batch_size: mini-batch size for both loaders
    :param learning_rate: SGD learning rate
    :param num_epoches: number of training epochs
    :return: None
    """
    model = CNN()

    # Loss and optimizer (CPU only; CUDA intentionally disabled here).
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    data_list = data_file.split(',')

    train_dataset = DealDataset(data_list[0], data_list[2],
                                transform=transforms.ToTensor())
    test_dataset = DealDataset(data_list[1], data_list[3],
                               transform=transforms.ToTensor())

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             shuffle=False)

    for epoch in range(num_epoches):
        # BUG FIX: restore training mode each epoch — model.eval() below
        # previously left BatchNorm frozen for every epoch after the first.
        model.train()
        for img, label in train_loader:
            img = img.float()

            out = model(img)
            loss = criterion(out, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluation on the test set; gradients are not needed here.
        model.eval()
        eval_loss = 0.0
        eval_acc = 0
        with torch.no_grad():
            for img, label in test_loader:
                img = img.float()

                out = model(img)
                loss = criterion(out, label)

                # Weight batch loss by batch size so the mean below is exact.
                eval_loss += loss.item() * label.size(0)
                _, pred = torch.max(out, 1)
                eval_acc += (pred == label).sum().item()

        print('EPOCH: ', epoch + 1)
        print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
            eval_loss / (len(test_dataset)),
            eval_acc / (len(test_dataset))
        ))
        # Per-epoch reporting hook (epoch, loss, accuracy):
        # report_dl_epoch(epoch, eval_loss / len(test_dataset), eval_acc / len(test_dataset))

    # Save the weights under a unique file name.
    model_file = str(uuid1()) + ".pt"
    model_path = os.path.join(output_dir, model_file)
    torch.save(model.state_dict(), model_path)
    print("### save model done.")

    # BUG FIX: model.txt previously recorded `output_dir + model_file`
    # (missing path separator) and was itself written to
    # `output_dir + 'model.txt'` — record and use the actual save path.
    with open(os.path.join(output_dir, 'model.txt'), 'w') as f:
        f.write(model_path)
    print("### write trained model path and name to: model.txt done.")

    # Report where the model was saved:
    # report_data(100, 100, "asfasf", model_path)


def parse_arguments():
    """Parse command line arguments for the training script."""
    cli = argparse.ArgumentParser(description='Kubeflow MNIST train model script')
    # Required paths.
    cli.add_argument('--data_dir', type=str, required=True, help='output dir')
    cli.add_argument('--data_file', type=str, required=True, help='a file write dataset file dir')
    # Training hyper-parameters.
    cli.add_argument('--batch_size', type=int, default=128, help='batch size')
    cli.add_argument('--lr', type=float, default=0.01, help='learning rate')
    cli.add_argument('--epoches', type=int, default=20, help='num epoches')
    return cli.parse_args()


def run():
    """Entry point: parse CLI arguments and launch training."""
    opts = parse_arguments()
    train_model(opts.data_file, opts.data_dir,
                opts.batch_size, opts.lr, opts.epoches)


if __name__ == '__main__':
    run()
